#include "rigid.hpp"

#include "comm.h"
#define forward_comm_rigid forward_comm_shake

/*
 * Enforce the rigid constraints on the current coordinates (no forces).
 *
 * For every local cell, a scratch copy of the positions (local atoms plus the
 * guest copies of neighbour-cell atoms, stored from offset CELL_CAP) is handed
 * to T::run_static for each rigid cluster anchored at a local atom. The
 * per-atom displacement (scratch - original) is accumulated in shake_xuc.
 * Guest displacements are scattered back to their owning (possibly ghost)
 * cells. Ghost contributions are then summed onto their owners by
 * reverse_comm_shake, the displacement is added to the local x, and the ghost
 * copies of x are refreshed with forward_comm_x.
 *
 * NOTE: shake_xuc is used here purely as a displacement buffer; its previous
 * content is discarded.
 */
template<class T>
void rigid_coordinates_listed(cellgrid_t *grid, mpp_t *mpp, T &rigid) {
  /* Zero the displacement accumulator of every cell, local and ghost. */
  FOREACH_CELL(grid, cx, cy, cz, cell) {
    for (int i = 0; i < cell->natom; i++) {
      cell->shake_xuc[i] = 0;
    }
  }
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    /* Collect x and rmass of guests (TODO: could be done during nonbonded). */
    vec<real> *guest_x = icell->x + CELL_CAP;
    vec<real> *guest_xuc = icell->shake_xuc + CELL_CAP;
    real *guest_rmass = icell->rmass + CELL_CAP;
    FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
      int did = hdcell(grid->nn, dx, dy, dz);
      /* Shift guest positions into icell's frame. */
      vec<real> dcell = jcell->basis - icell->basis;

      for (int jj = icell->first_guest_cell[did]; jj < icell->first_guest_cell[did + 1]; jj++) {
        int j = icell->guest_id[jj];
        guest_x[jj] = dcell + jcell->x[j];
        guest_xuc[jj] = 0;
        guest_rmass[jj] = jcell->rmass[j];
      }
    }

    /* Scratch positions: seeded with x, expected to be corrected by run_static. */
    vec<real> tmp_xuc[CELL_CAP + MAX_CELL_GUEST];
    for (int i = 0; i < icell->natom; i ++) {
      tmp_xuc[i] = icell->x[i];
    }
    for (int i = CELL_CAP; i < CELL_CAP + icell->nguest; i ++) {
      tmp_xuc[i] = icell->x[i];
    }

    for (int i = 0; i < icell->natom; i++) {
      /* idx[0] is the anchor atom; idx[1..3] index local or guest slots. */
      int idx[4];
      idx[0] = i;
      idx[1] = icell->rigid[i].id[0];
      idx[2] = icell->rigid[i].id[1];
      idx[3] = icell->rigid[i].id[2];
      switch (icell->rigid[i].type) {
      case RIGID2:
        rigid.template run_static<rigid2>(icell->x, tmp_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      case RIGID3:
        rigid.template run_static<rigid3>(icell->x, tmp_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      case RIGID4:
        rigid.template run_static<rigid4>(icell->x, tmp_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      case RIGID3ANGLE:
        rigid.template run_static<rigid3angle>(icell->x, tmp_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      default:
        break;
      }
    }

    /*
     * Accumulate (+=, not =): a neighbour cell processed earlier in this loop
     * may already have scattered a displacement for one of our atoms. A plain
     * assignment would discard it. All local slots were zeroed above.
     */
    for (int i = 0; i < icell->natom; i ++) {
      icell->shake_xuc[i] += tmp_xuc[i] - icell->x[i];
    }
    /* Guest slots are per-cell scratch, so plain assignment is fine. */
    for (int i = CELL_CAP; i < CELL_CAP + icell->nguest; i ++) {
      icell->shake_xuc[i] = tmp_xuc[i] - icell->x[i];
    }

    /* Scatter guest displacements back to the cells that own those atoms. */
    FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
      int did = hdcell(grid->nn, dx, dy, dz);
      for (int jj = icell->first_guest_cell[did]; jj < icell->first_guest_cell[did + 1]; jj++) {
        int j = icell->guest_id[jj];
        jcell->shake_xuc[j] += guest_xuc[jj];
      }
    }
  }

  /* Sum ghost displacements onto their owners, apply, and refresh ghost x. */
  reverse_comm_shake(grid, mpp);
  FOREACH_LOCAL_CELL(grid, ii, jj, kk, cell) {
    for (int i = 0; i < cell->natom; i++) {
      cell->x[i] += cell->shake_xuc[i];
    }
  }
  forward_comm_x(grid, mpp);
}
#ifndef __sw_64__

/*
 * Rigid-constraint force kernel, generic (non-Sunway) build.
 *
 * Ghost-cell forces are zeroed first, because forces had already been
 * communicated once. After this call, ghost cells therefore hold only the
 * constraint contributions computed here; the rigid_post_force* drivers
 * reverse-communicate them afterwards.
 *
 * For each local cell:
 *   - the guest slots (offset CELL_CAP) are filled with neighbour-cell copies
 *     of x, shake_xuc and rmass, with positions shifted into the cell's frame;
 *   - T::run is invoked for every rigid cluster anchored at a local atom;
 *   - the guest forces are scattered back onto the owning cells.
 *
 * dtfsq : dt*dt*ftm2v-style step factor supplied by the rigid_post_force*
 *         callers; T::run receives its reciprocal.
 * NOTE(review): guest shake_xuc is shifted by the cell offset, which assumes
 * shake_xuc holds absolute positions at this point (presumably written by
 * unconstrained_update) — confirm.
 */
template <class T>
void rigid_force_listed(cellgrid_t *grid, real dtfsq, T &rigid){
  real rdtfsq = 1.0 / dtfsq;
  /* Clear ghost forces: forces had been communicated once already. */
  FOREACH_CELL(grid, cx, cy, cz, icell) {
    if (cx < 0 || cy < 0 || cz < 0 || cx >= grid->nlocal.x || cy >= grid->nlocal.y || cz >= grid->nlocal.z) {
      for (int i = 0; i < icell->natom; i++) {
        icell->f[i] = 0;
      }
    }
  }
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    vec<real> *guest_x = icell->x + CELL_CAP;
    vec<real> *guest_xuc = icell->shake_xuc + CELL_CAP;
    vec<real> *guest_f = icell->f + CELL_CAP;
    real *guest_rmass = icell->rmass + CELL_CAP;
    FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
      int did = hdcell(grid->nn, dx, dy, dz);
      vec<real> dcell = jcell->basis - icell->basis;
      for (int jj = icell->first_guest_cell[did]; jj < icell->first_guest_cell[did + 1]; jj ++) {
        int j = icell->guest_id[jj];
        guest_x[jj] = dcell + jcell->x[j];
        guest_xuc[jj] = dcell + jcell->shake_xuc[j];
        /* T::run reads masses through icell->rmass, including guest slots
         * addressed by idx[1..3]. Refresh them here so they cannot be stale
         * after the guest lists change. */
        guest_rmass[jj] = jcell->rmass[j];
        guest_f[jj] = 0;
      }
    }

    for (int i = 0; i < icell->natom; i ++) {
      /* idx[0] is the anchor atom; idx[1..3] index local or guest slots. */
      int idx[4];
      idx[0] = i;
      idx[1] = icell->rigid[i].id[0];
      idx[2] = icell->rigid[i].id[1];
      idx[3] = icell->rigid[i].id[2];

      switch (icell->rigid[i].type) {
      case RIGID2:
        rigid.template run<rigid2>(icell->f, rdtfsq, icell->x, icell->shake_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      case RIGID3:
        rigid.template run<rigid3>(icell->f, rdtfsq, icell->x, icell->shake_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      case RIGID4:
        rigid.template run<rigid4>(icell->f, rdtfsq, icell->x, icell->shake_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      case RIGID3ANGLE:
        rigid.template run<rigid3angle>(icell->f, rdtfsq, icell->x, icell->shake_xuc, icell->rmass, icell->rigid[i], idx);
        break;
      default:
        break;
      }
    }

    /* Scatter guest forces back onto the cells that own those atoms. */
    FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
      int did = hdcell(grid->nn, dx, dy, dz);
      for (int jj = icell->first_guest_cell[did]; jj < icell->first_guest_cell[did + 1]; jj ++) {
        int j = icell->guest_id[jj];
        jcell->f[j] += guest_f[jj];
      }
    }
  }
}
#else
/* Sunway (__sw_64__) build: the rigid force kernels are only declared here;
 * their definitions live outside this file. */
template <class T>
extern void rigid_force_sw(cellgrid_t *grid, real dtfsq, T &rigid);
template <class T>
extern void rigid_force_listed_sw(cellgrid_t *grid, real dtfsq, T &rigid);
/* Sunway build: thin forwarder to rigid_force_sw (defined outside this file). */
template <class T>
void rigid_force(cellgrid_t *grid, real dtfsq, T &rigid) {
  return rigid_force_sw(grid, dtfsq, rigid);
}
/* Sunway build: thin forwarder to rigid_force_listed_sw (defined outside this file). */
template <class T>
void rigid_force_listed(cellgrid_t *grid, real dtfsq, T &rigid) {
  return rigid_force_listed_sw(grid, dtfsq, rigid);
}
#endif
/* Defined outside this file; the post-force drivers below call it right before
 * forward_comm_shake and rigid_force_listed.
 * NOTE(review): presumably writes the unconstrained (pre-constraint) positions
 * into shake_xuc — confirm against its definition. */
extern void unconstrained_update(cellgrid_t *grid, real dt, real dtfsq);
template <typename T>
void rigid_setup(cellgrid_t *grid, mpp_t *mpp, real dt, real ftm2v, T &rigid){
  real dtfsq = dt * dt * ftm2v;

  rigid_coordinates_listed(grid, mpp, rigid);
  rigid_post_force_listed(grid, mpp, dt, 0.5*ftm2v, rigid);
  // reverse_comm_f(grid, mpp);
}

/*
 * Per-step constraint correction. The sequence is:
 *   1. unconstrained_update with the step factor dt*dt*ftm2v;
 *   2. refresh the ghost copies of shake_xuc;
 *   3. add the constraint forces;
 *   4. sum ghost-cell force contributions back onto their owners.
 */
template <typename T>
void rigid_post_force(cellgrid_t *grid, mpp_t *mpp, real dt, real ftm2v, T &rigid) {
  const real step_factor = dt * dt * ftm2v;

  unconstrained_update(grid, dt, step_factor);
  forward_comm_shake(grid, mpp);

  rigid_force_listed(grid, step_factor, rigid);
  reverse_comm_f(grid, mpp);
}

/*
 * Listed-variant post-force constraint pass. It computes
 * dtf_sq = dt*dt*ftm2v, updates the unconstrained state, and exchanges
 * shake_xuc with the ghost cells. It then applies the constraint forces and
 * reverse-communicates forces from ghost cells to their owners.
 */
template <typename T>
void rigid_post_force_listed(cellgrid_t *grid, mpp_t *mpp, real dt, real ftm2v, T &rigid) {
  const real dtf_sq = dt * dt * ftm2v;
  unconstrained_update(grid, dt, dtf_sq);
  forward_comm_shake(grid, mpp);      /* ghosts need current shake_xuc */
  rigid_force_listed(grid, dtf_sq, rigid);
  reverse_comm_f(grid, mpp);          /* fold ghost forces onto owners */
}
